Statistical learning: classification and cross-validation

MACS 30500 University of Chicago

Should I Have a Cookie?

Interpreting a decision tree

A more complex tree

An even more complex tree

Benefits/drawbacks to decision trees

  • Easy to explain
  • Easy to interpret/visualize
  • Good for qualitative predictors
  • Lower accuracy rates
  • Non-robust

Random forests

Sampling

# create a vector of the integers 1-10; wrapping the assignment in
# parentheses both assigns and prints the result
(numbers <- seq(from = 1, to = 10))
##  [1]  1  2  3  4  5  6  7  8  9 10
# sample without replacement: each draw is a permutation that uses
# every value exactly once
# NOTE(review): purrr::rerun() is superseded/deprecated in recent purrr;
# map(1:5, ~ sample(...)) is the modern equivalent — confirm against the
# installed purrr version
rerun(5, sample(numbers, replace = FALSE))
## [[1]]
##  [1]  6  4  1 10  9  7  5  2  3  8
## 
## [[2]]
##  [1]  9  8  7  1  4  5  3 10  6  2
## 
## [[3]]
##  [1]  2  4  7  1 10  8  3  5  6  9
## 
## [[4]]
##  [1]  4  6  3  1  7 10  5  8  2  9
## 
## [[5]]
##  [1]  8  6  7  5  9  3 10  1  4  2
# sample with replacement: the same value can be drawn multiple times,
# which is the basis of bootstrap resampling
rerun(5, sample(numbers, replace = TRUE))
## [[1]]
##  [1]  5  4  2  3 10  1  5  3  5  8
## 
## [[2]]
##  [1]  8 10  2  9  9 10  9  1  5  3
## 
## [[3]]
##  [1]  6  3  1 10  4  3  7  8  1  7
## 
## [[4]]
##  [1]  2  3  7  2  4 10  8  5  4  8
## 
## [[5]]
##  [1]  3  3  4  9  9  1 10  2  1  7

Random forests

  • Bootstrapping
  • Reduces variance
  • Bagging
  • Random forest
    • Reliability

Estimating statistical models using caret

  • Not part of tidyverse (yet)
  • Aggregator of hundreds of statistical learning algorithms
  • Provides a single unified interface to disparate range of functions
    • Similar to scikit-learn for Python

train()

library(caret)

# drop observations with a missing outcome or predictor before modeling
titanic_clean <- titanic %>%
  filter(!is.na(Survived), !is.na(Age))

# logistic regression via caret; trainControl(method = "none") disables
# resampling, so train() performs a single model fit on the full data
caret_glm <- train(Survived ~ Age, data = titanic_clean,
                   method = "glm",
                   family = binomial,
                   trControl = trainControl(method = "none"))
summary(caret_glm)
## 
## Call:
## NULL
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -1.1488  -1.0361  -0.9544   1.3159   1.5908  
## 
## Coefficients:
##             Estimate Std. Error z value Pr(>|z|)  
## (Intercept) -0.05672    0.17358  -0.327   0.7438  
## Age         -0.01096    0.00533  -2.057   0.0397 *
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 964.52  on 713  degrees of freedom
## Residual deviance: 960.23  on 712  degrees of freedom
## AIC: 964.23
## 
## Number of Fisher Scoring iterations: 4

Estimating a random forest

# random forest with 200 trees; method = "oob" uses the out-of-bag
# samples to estimate performance instead of a separate resampling loop
age_sex_rf <- train(Survived ~ Age + Sex, data = titanic_rf_data,
                   method = "rf",
                   ntree = 200,
                   trControl = trainControl(method = "oob"))
## note: only 1 unique complexity parameters in default grid. Truncating the grid to 1 .
age_sex_rf
## Random Forest 
## 
## 714 samples
##   2 predictor
##   2 classes: 'Died', 'Survived' 
## 
## No pre-processing
## Resampling results:
## 
##   Accuracy   Kappa    
##   0.7507003  0.4734426
## 
## Tuning parameter 'mtry' was held constant at a value of 2

Structure of train() object

## List of 24
##  $ method      : chr "rf"
##  $ modelInfo   :List of 15
##  $ modelType   : chr "Classification"
##  $ results     :'data.frame':    1 obs. of  3 variables:
##  $ pred        : NULL
##  $ bestTune    :'data.frame':    1 obs. of  1 variable:
##  $ call        : language train.formula(form = Survived ~ Age + Sex, data = titanic_rf_data,      method = "rf", ntree = 200, trControl = t| __truncated__
##  $ dots        :List of 1
##  $ metric      : chr "Accuracy"
##  $ control     :List of 26
##  $ finalModel  :List of 23
##   ..- attr(*, "class")= chr "randomForest"
##  $ preProcess  : NULL
##  $ trainingData:Classes 'tbl_df', 'tbl' and 'data.frame':    714 obs. of  3 variables:
##  $ resample    : NULL
##  $ resampledCM : NULL
##  $ perfNames   : chr [1:2] "Accuracy" "Kappa"
##  $ maximize    : logi TRUE
##  $ yLimits     : NULL
##  $ times       :List of 3
##  $ levels      : chr [1:2] "Died" "Survived"
##   ..- attr(*, "ordered")= logi FALSE
##  $ terms       :Classes 'terms', 'formula'  language Survived ~ Age + Sex
##   .. ..- attr(*, "variables")= language list(Survived, Age, Sex)
##   .. ..- attr(*, "factors")= int [1:3, 1:2] 0 1 0 0 0 1
##   .. .. ..- attr(*, "dimnames")=List of 2
##   .. ..- attr(*, "term.labels")= chr [1:2] "Age" "Sex"
##   .. ..- attr(*, "order")= int [1:2] 1 1
##   .. ..- attr(*, "intercept")= int 1
##   .. ..- attr(*, "response")= int 1
##   .. ..- attr(*, ".Environment")=<environment: R_GlobalEnv> 
##   .. ..- attr(*, "predvars")= language list(Survived, Age, Sex)
##   .. ..- attr(*, "dataClasses")= Named chr [1:3] "factor" "numeric" "factor"
##   .. .. ..- attr(*, "names")= chr [1:3] "Survived" "Age" "Sex"
##  $ coefnames   : chr [1:2] "Age" "Sexmale"
##  $ contrasts   :List of 1
##  $ xlevels     :List of 1
##  - attr(*, "class")= chr [1:2] "train" "train.formula"

Model statistics

# extract the underlying randomForest object fitted by caret
age_sex_rf$finalModel
## 
## Call:
##  randomForest(x = x, y = y, ntree = 200, mtry = param$mtry) 
##                Type of random forest: classification
##                      Number of trees: 200
## No. of variables tried at each split: 2
## 
##         OOB estimate of  error rate: 24.23%
## Confusion matrix:
##          Died Survived class.error
## Died      357       67   0.1580189
## Survived  106      184   0.3655172

Results of a single tree

# inspect a single tree from the forest; labelVar = TRUE shows variable
# names instead of column indices
randomForest::getTree(age_sex_rf$finalModel, labelVar = TRUE)
##     left daughter right daughter split var split point status prediction
## 1               2              3   Sexmale       0.500      1       <NA>
## 2               4              5       Age      14.750      1       <NA>
## 3               6              7       Age       3.500      1       <NA>
## 4               8              9       Age       8.000      1       <NA>
## 5              10             11       Age      36.500      1       <NA>
## 6              12             13       Age       2.500      1       <NA>
## 7              14             15       Age      25.500      1       <NA>
## 8              16             17       Age       3.500      1       <NA>
## 9              18             19       Age      12.000      1       <NA>
## 10             20             21       Age      31.500      1       <NA>
## 11             22             23       Age      37.500      1       <NA>
## 12             24             25       Age       1.500      1       <NA>
## 13              0              0      <NA>       0.000     -1   Survived
## 14             26             27       Age      13.000      1       <NA>
## 15             28             29       Age      53.000      1       <NA>
## 16             30             31       Age       1.375      1       <NA>
## 17             32             33       Age       5.500      1       <NA>
## 18              0              0      <NA>       0.000     -1       Died
## 19             34             35       Age      13.500      1       <NA>
## 20             36             37       Age      30.500      1       <NA>
## 21              0              0      <NA>       0.000     -1   Survived
## 22              0              0      <NA>       0.000     -1       Died
## 23             38             39       Age      48.500      1       <NA>
## 24             40             41       Age       0.915      1       <NA>
## 25              0              0      <NA>       0.000     -1       Died
## 26             42             43       Age      11.500      1       <NA>
## 27             44             45       Age      20.500      1       <NA>
## 28             46             47       Age      48.500      1       <NA>
## 29             48             49       Age      60.500      1       <NA>
## 30              0              0      <NA>       0.000     -1   Survived
## 31             50             51       Age       2.500      1       <NA>
## 32              0              0      <NA>       0.000     -1   Survived
## 33             52             53       Age       6.500      1       <NA>
## 34              0              0      <NA>       0.000     -1   Survived
## 35             54             55       Age      14.250      1       <NA>
## 36             56             57       Age      26.500      1       <NA>
## 37              0              0      <NA>       0.000     -1       Died
## 38             58             59       Age      44.500      1       <NA>
## 39             60             61       Age      56.500      1       <NA>
## 40              0              0      <NA>       0.000     -1   Survived
## 41              0              0      <NA>       0.000     -1   Survived
## 42             62             63       Age       9.500      1       <NA>
## 43              0              0      <NA>       0.000     -1   Survived
## 44             64             65       Age      19.500      1       <NA>
## 45             66             67       Age      23.750      1       <NA>
## 46             68             69       Age      32.250      1       <NA>
## 47             70             71       Age      49.500      1       <NA>
## 48             72             73       Age      59.500      1       <NA>
## 49              0              0      <NA>       0.000     -1       Died
## 50              0              0      <NA>       0.000     -1       Died
## 51              0              0      <NA>       0.000     -1       Died
## 52              0              0      <NA>       0.000     -1   Survived
## 53              0              0      <NA>       0.000     -1   Survived
## 54              0              0      <NA>       0.000     -1       Died
## 55              0              0      <NA>       0.000     -1       Died
## 56             74             75       Age      25.500      1       <NA>
## 57             76             77       Age      27.500      1       <NA>
## 58             78             79       Age      41.500      1       <NA>
## 59             80             81       Age      46.000      1       <NA>
## 60              0              0      <NA>       0.000     -1   Survived
## 61             82             83       Age      57.500      1       <NA>
## 62             84             85       Age       8.000      1       <NA>
## 63              0              0      <NA>       0.000     -1       Died
## 64             86             87       Age      15.500      1       <NA>
## 65              0              0      <NA>       0.000     -1       Died
## 66             88             89       Age      22.500      1       <NA>
## 67             90             91       Age      24.250      1       <NA>
## 68             92             93       Age      28.750      1       <NA>
## 69             94             95       Age      41.500      1       <NA>
## 70              0              0      <NA>       0.000     -1   Survived
## 71             96             97       Age      51.500      1       <NA>
## 72              0              0      <NA>       0.000     -1       Died
## 73              0              0      <NA>       0.000     -1   Survived
## 74             98             99       Age      19.500      1       <NA>
## 75              0              0      <NA>       0.000     -1       Died
## 76              0              0      <NA>       0.000     -1   Survived
## 77            100            101       Age      29.500      1       <NA>
## 78            102            103       Age      40.500      1       <NA>
## 79              0              0      <NA>       0.000     -1   Survived
## 80              0              0      <NA>       0.000     -1       Died
## 81            104            105       Age      47.500      1       <NA>
## 82              0              0      <NA>       0.000     -1       Died
## 83              0              0      <NA>       0.000     -1   Survived
## 84            106            107       Age       5.500      1       <NA>
## 85              0              0      <NA>       0.000     -1   Survived
## 86              0              0      <NA>       0.000     -1       Died
## 87            108            109       Age      16.500      1       <NA>
## 88            110            111       Age      21.500      1       <NA>
## 89              0              0      <NA>       0.000     -1       Died
## 90              0              0      <NA>       0.000     -1       Died
## 91            112            113       Age      24.750      1       <NA>
## 92            114            115       Age      27.500      1       <NA>
## 93            116            117       Age      29.500      1       <NA>
## 94            118            119       Age      33.500      1       <NA>
## 95            120            121       Age      42.500      1       <NA>
## 96            122            123       Age      50.500      1       <NA>
## 97              0              0      <NA>       0.000     -1   Survived
## 98            124            125       Age      18.500      1       <NA>
## 99            126            127       Age      20.500      1       <NA>
## 100           128            129       Age      28.500      1       <NA>
## 101             0              0      <NA>       0.000     -1   Survived
## 102           130            131       Age      39.500      1       <NA>
## 103             0              0      <NA>       0.000     -1       Died
## 104             0              0      <NA>       0.000     -1   Survived
## 105             0              0      <NA>       0.000     -1       Died
## 106             0              0      <NA>       0.000     -1       Died
## 107             0              0      <NA>       0.000     -1       Died
## 108             0              0      <NA>       0.000     -1       Died
## 109           132            133       Age      17.500      1       <NA>
## 110             0              0      <NA>       0.000     -1       Died
## 111             0              0      <NA>       0.000     -1       Died
## 112             0              0      <NA>       0.000     -1       Died
## 113             0              0      <NA>       0.000     -1       Died
## 114           134            135       Age      26.500      1       <NA>
## 115           136            137       Age      28.250      1       <NA>
## 116             0              0      <NA>       0.000     -1   Survived
## 117           138            139       Age      30.750      1       <NA>
## 118             0              0      <NA>       0.000     -1       Died
## 119           140            141       Age      37.500      1       <NA>
## 120             0              0      <NA>       0.000     -1   Survived
## 121           142            143       Age      47.500      1       <NA>
## 122             0              0      <NA>       0.000     -1   Survived
## 123             0              0      <NA>       0.000     -1       Died
## 124           144            145       Age      15.500      1       <NA>
## 125             0              0      <NA>       0.000     -1   Survived
## 126             0              0      <NA>       0.000     -1       Died
## 127           146            147       Age      23.500      1       <NA>
## 128             0              0      <NA>       0.000     -1   Survived
## 129             0              0      <NA>       0.000     -1   Survived
## 130           148            149       Age      38.500      1       <NA>
## 131             0              0      <NA>       0.000     -1   Survived
## 132             0              0      <NA>       0.000     -1       Died
## 133           150            151       Age      18.500      1       <NA>
## 134             0              0      <NA>       0.000     -1       Died
## 135             0              0      <NA>       0.000     -1       Died
## 136             0              0      <NA>       0.000     -1       Died
## 137             0              0      <NA>       0.000     -1       Died
## 138           152            153       Age      30.250      1       <NA>
## 139           154            155       Age      31.500      1       <NA>
## 140           156            157       Age      36.750      1       <NA>
## 141           158            159       Age      38.500      1       <NA>
## 142           160            161       Age      45.250      1       <NA>
## 143             0              0      <NA>       0.000     -1       Died
## 144             0              0      <NA>       0.000     -1   Survived
## 145           162            163       Age      17.500      1       <NA>
## 146           164            165       Age      22.500      1       <NA>
## 147           166            167       Age      24.500      1       <NA>
## 148             0              0      <NA>       0.000     -1   Survived
## 149             0              0      <NA>       0.000     -1   Survived
## 150             0              0      <NA>       0.000     -1       Died
## 151             0              0      <NA>       0.000     -1       Died
## 152             0              0      <NA>       0.000     -1       Died
## 153             0              0      <NA>       0.000     -1       Died
## 154             0              0      <NA>       0.000     -1       Died
## 155             0              0      <NA>       0.000     -1       Died
## 156           168            169       Age      35.500      1       <NA>
## 157             0              0      <NA>       0.000     -1       Died
## 158             0              0      <NA>       0.000     -1       Died
## 159           170            171       Age      39.500      1       <NA>
## 160           172            173       Age      44.500      1       <NA>
## 161             0              0      <NA>       0.000     -1       Died
## 162           174            175       Age      16.500      1       <NA>
## 163             0              0      <NA>       0.000     -1   Survived
## 164           176            177       Age      21.500      1       <NA>
## 165             0              0      <NA>       0.000     -1   Survived
## 166             0              0      <NA>       0.000     -1   Survived
## 167             0              0      <NA>       0.000     -1   Survived
## 168           178            179       Age      34.500      1       <NA>
## 169           180            181       Age      36.250      1       <NA>
## 170             0              0      <NA>       0.000     -1       Died
## 171             0              0      <NA>       0.000     -1       Died
## 172           182            183       Age      43.500      1       <NA>
## 173             0              0      <NA>       0.000     -1       Died
## 174             0              0      <NA>       0.000     -1   Survived
## 175             0              0      <NA>       0.000     -1   Survived
## 176             0              0      <NA>       0.000     -1   Survived
## 177             0              0      <NA>       0.000     -1   Survived
## 178             0              0      <NA>       0.000     -1       Died
## 179             0              0      <NA>       0.000     -1       Died
## 180             0              0      <NA>       0.000     -1       Died
## 181             0              0      <NA>       0.000     -1       Died
## 182             0              0      <NA>       0.000     -1       Died
## 183             0              0      <NA>       0.000     -1       Died

Variable importance

Exercise: depression and voting

Resampling methods

  • Evaluating model fit/predictive power
  • How to avoid overfitting the data

Validation set

  • Randomly split data into two distinct sets
    • Training set
    • Test set
  • Train model on training set
  • Evaluate fit on test set

Regression

Mean squared error

\[MSE = \frac{1}{n} \sum_{i = 1}^{n}{(y_i - \hat{f}(x_i))^2}\]

  • \(y_i =\) the observed response value for the \(i\)th observation
  • \(\hat{f}(x_i) =\) the predicted response value for the \(i\)th observation given by \(\hat{f}\)
  • \(n =\) the total number of observations

Split data

# fix the RNG seed so the random split below is reproducible
set.seed(1234)

# randomly assign half the observations to training, half to test
auto_split <- initial_split(data = Auto, prop = 0.5)
auto_train <- training(auto_split)
auto_test <- testing(auto_split)

Train model

# glm() without a family argument defaults to gaussian,
# i.e. ordinary linear regression
auto_lm <- glm(mpg ~ horsepower, data = auto_train)
summary(auto_lm)
## 
## Call:
## glm(formula = mpg ~ horsepower, data = auto_train)
## 
## Deviance Residuals: 
##      Min        1Q    Median        3Q       Max  
## -13.7105   -3.4442   -0.5342    2.6256   15.1015  
## 
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)    
## (Intercept) 40.057910   1.054798   37.98   <2e-16 ***
## horsepower  -0.157604   0.009402  -16.76   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for gaussian family taken to be 24.80151)
## 
##     Null deviance: 11780.6  on 195  degrees of freedom
## Residual deviance:  4811.5  on 194  degrees of freedom
## AIC: 1189.6
## 
## Number of Fisher Scoring iterations: 2
# training-set MSE: mean squared residual on the data used to fit the
# model (%$% is magrittr's exposition pipe, exposing columns to mean())
(train_mse <- augment(auto_lm, newdata = auto_train) %>%
  mutate(.resid = mpg - .fitted,
         .resid2 = .resid ^ 2) %$%
  mean(.resid2))
## [1] 24.54843

Test model

# test-set MSE: the same calculation on the held-out observations
(test_mse <- augment(auto_lm, newdata = auto_test) %>%
  mutate(.resid = mpg - .fitted,
         .resid2 = .resid ^ 2) %$%
  mean(.resid2))
## [1] 23.38243

Compare models

Classification

# logistic regression with an Age-by-Sex interaction term
survive_age_woman_x <- glm(Survived ~ Age * Sex, data = titanic,
                           family = binomial)
summary(survive_age_woman_x)
## 
## Call:
## glm(formula = Survived ~ Age * Sex, family = binomial, data = titanic)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -1.9401  -0.7136  -0.5883   0.7626   2.2455  
## 
## Coefficients:
##             Estimate Std. Error z value Pr(>|z|)   
## (Intercept)  0.59380    0.31032   1.913  0.05569 . 
## Age          0.01970    0.01057   1.863  0.06240 . 
## Sexmale     -1.31775    0.40842  -3.226  0.00125 **
## Age:Sexmale -0.04112    0.01355  -3.034  0.00241 **
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 964.52  on 713  degrees of freedom
## Residual deviance: 740.40  on 710  degrees of freedom
##   (177 observations deleted due to missingness)
## AIC: 748.4
## 
## Number of Fisher Scoring iterations: 4

Test error rate

# split the data into training and validation sets
# NOTE(review): no set.seed() before initial_split(), so this split —
# and the error rate below — is not reproducible across runs
titanic_split <- initial_split(data = titanic, prop = 0.5)

# fit model to training data
train_model <- glm(Survived ~ Age * Sex,
                   data = training(titanic_split),
                   family = binomial)
summary(train_model)
## 
## Call:
## glm(formula = Survived ~ Age * Sex, family = binomial, data = training(titanic_split))
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -1.9374  -0.7041  -0.5866   0.7644   2.1918  
## 
## Coefficients:
##             Estimate Std. Error z value Pr(>|z|)  
## (Intercept)  0.58906    0.41752   1.411   0.1583  
## Age          0.01968    0.01414   1.391   0.1642  
## Sexmale     -1.42528    0.55970  -2.546   0.0109 *
## Age:Sexmale -0.03806    0.01829  -2.080   0.0375 *
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 485.10  on 358  degrees of freedom
## Residual deviance: 370.14  on 355  degrees of freedom
##   (87 observations deleted due to missingness)
## AIC: 378.14
## 
## Number of Fisher Scoring iterations: 4
# calculate predictions using validation set
# logit2prob() converts the log-odds in .fitted to probabilities;
# classes are assigned with a 0.5 probability cutoff
x_test_accuracy <- augment(train_model,
                           newdata = testing(titanic_split)) %>% 
  as_tibble() %>%
  mutate(pred = logit2prob(.fitted),
         pred = as.numeric(pred > .5))

# calculate test error rate: share of validation rows misclassified
mean(x_test_accuracy$Survived != x_test_accuracy$pred, na.rm = TRUE)
## [1] 0.2225352

Drawbacks to validation sets

Leave-one-out cross-validation

\[CV_{(n)} = \frac{1}{n} \sum_{i = 1}^{n}{MSE_i}\]

  • Extension of validation set to repeatedly split data and average results
  • Minimizes bias of estimated error rate
  • Low variance
  • Highly computationally intensive

rsample::loo_cv()

# create one resample per observation: each split holds out exactly
# one row for assessment
loocv_data <- loo_cv(Auto)
loocv_data
## # Leave-one-out cross-validation 
## # A tibble: 392 x 2
##    splits       id        
##    <list>       <chr>     
##  1 <S3: rsplit> Resample1 
##  2 <S3: rsplit> Resample2 
##  3 <S3: rsplit> Resample3 
##  4 <S3: rsplit> Resample4 
##  5 <S3: rsplit> Resample5 
##  6 <S3: rsplit> Resample6 
##  7 <S3: rsplit> Resample7 
##  8 <S3: rsplit> Resample8 
##  9 <S3: rsplit> Resample9 
## 10 <S3: rsplit> Resample10
## # ... with 382 more rows

Splits

# examine a single rsplit object from the LOOCV set
first_resample <- loocv_data$splits[[1]]
first_resample
## <391/1/392>
# analysis (training) set: all observations except the one held out
training(first_resample)
## # A tibble: 391 x 9
##      mpg cylinders displacement horsepower weight acceleration  year origin
##    <dbl>     <dbl>        <dbl>      <dbl>  <dbl>        <dbl> <dbl>  <dbl>
##  1    18         8          307        130   3504         12      70      1
##  2    15         8          350        165   3693         11.5    70      1
##  3    18         8          318        150   3436         11      70      1
##  4    16         8          304        150   3433         12      70      1
##  5    17         8          302        140   3449         10.5    70      1
##  6    15         8          429        198   4341         10      70      1
##  7    14         8          454        220   4354          9      70      1
##  8    14         8          440        215   4312          8.5    70      1
##  9    14         8          455        225   4425         10      70      1
## 10    15         8          390        190   3850          8.5    70      1
## # ... with 381 more rows, and 1 more variable: name <fct>
assessment(first_resample)
## # A tibble: 1 x 9
##     mpg cylinders displacement horsepower weight acceleration  year origin
##   <dbl>     <dbl>        <dbl>      <dbl>  <dbl>        <dbl> <dbl>  <dbl>
## 1    14         8          318        150   4457         13.5    74      1
## # ... with 1 more variable: name <fct>

Holdout results

  1. Obtain the analysis data set (i.e. the \(n-1\) training set)
  2. Fit a linear regression model
  3. Predict the assessment data set using the broom package
  4. Determine the MSE for each sample

Holdout results

# For a single rsplit: fit the regression on the analysis set and
# return the held-out observation(s) augmented with predictions
# and residuals.
holdout_results <- function(splits) {
  # Train on the analysis portion (n - 1 observations under LOOCV)
  fitted_model <- glm(mpg ~ horsepower, data = analysis(splits))

  # Observation(s) held out from training for this resample
  heldout_obs <- assessment(splits)

  # Attach .fitted predictions, compute residuals for the later MSE
  # calculation, and return the augmented assessment data directly
  augment(fitted_model, newdata = heldout_obs) %>%
    mutate(.resid = mpg - .fitted)
}

Holdout results

# run the holdout function on a single split to check its output
holdout_results(loocv_data$splits[[1]])
## # A tibble: 1 x 12
##     mpg cylinders displacement horsepower weight acceleration  year origin
##   <dbl>     <dbl>        <dbl>      <dbl>  <dbl>        <dbl> <dbl>  <dbl>
## 1    14         8          318        150   4457         13.5    74      1
## # ... with 4 more variables: name <fct>, .fitted <dbl>, .se.fit <dbl>,
## #   .resid <dbl>
# apply the holdout function to every split, then compute the mean
# squared residual for each held-out observation
loocv_data$results <- map(loocv_data$splits, holdout_results)
loocv_data$mse <- map_dbl(loocv_data$results, ~ mean(.$.resid ^ 2))
loocv_data
## # Leave-one-out cross-validation 
## # A tibble: 392 x 4
##    splits       id         results               mse
##    <list>       <chr>      <list>              <dbl>
##  1 <S3: rsplit> Resample1  <tibble [1 × 12]>  5.17  
##  2 <S3: rsplit> Resample2  <tibble [1 × 12]>  1.77  
##  3 <S3: rsplit> Resample3  <tibble [1 × 12]>  2.07  
##  4 <S3: rsplit> Resample4  <tibble [1 × 12]>  2.40  
##  5 <S3: rsplit> Resample5  <tibble [1 × 12]> 14.8   
##  6 <S3: rsplit> Resample6  <tibble [1 × 12]>  2.77  
##  7 <S3: rsplit> Resample7  <tibble [1 × 12]> 56.9   
##  8 <S3: rsplit> Resample8  <tibble [1 × 12]> 22.6   
##  9 <S3: rsplit> Resample9  <tibble [1 × 12]>  0.0680
## 10 <S3: rsplit> Resample10 <tibble [1 × 12]> 50.1   
## # ... with 382 more rows
# the LOOCV estimate is the average of the n per-observation MSEs
loocv_data %>%
  summarize(mse = mean(mse))
## # Leave-one-out cross-validation 
## # A tibble: 1 x 1
##     mse
##   <dbl>
## 1  24.2

Compare polynomial terms

LOOCV in classification

# Generate assessment-set predictions for the titanic interaction model.
#
# splits: an rsample `rsplit` object.
# Returns the assessment data with `.fitted` (log-odds) and `pred`
# (0/1 predicted class at a 0.5 probability cutoff) columns added.
holdout_results <- function(splits) {
  # Fit the model to the n-1 analysis set
  mod <- glm(Survived ~ Age * Sex, data = analysis(splits),
             family = binomial)
  
  # Save the heldout observation
  holdout <- assessment(splits)
  
  # `augment` will save the predictions with the holdout data set;
  # reuse `holdout` rather than recomputing assessment(splits)
  res <- augment(mod, newdata = holdout) %>%
    as_tibble() %>%
    mutate(pred = logit2prob(.fitted),
           pred = as.numeric(pred > .5))

  # Return the assessment data set with the additional columns
  res
}

# estimate the LOOCV error rate: predict each held-out observation,
# then average the per-split misclassification rates
titanic_loocv <- loo_cv(titanic) %>%
  mutate(results = map(splits, holdout_results),
         error_rate = map_dbl(results, ~ mean(.$Survived != .$pred,
                                              na.rm = TRUE)))
mean(titanic_loocv$error_rate, na.rm = TRUE)
## [1] 0.219888

Exercise: LOOCV in linear regression

\(k\)-fold cross-validation

\[CV_{(k)} = \frac{1}{k} \sum_{i = 1}^{k}{MSE_i}\]

  • Split data into \(k\) folds
  • Repeat training/test process for each fold
  • LOOCV: \(k=n\)

k-fold CV in linear regression

# Fit a polynomial regression whose highest-order term is `i` on the
# analysis set, and return the held-out fold augmented with
# predictions and residuals.
holdout_results <- function(splits, i) {
  # Train the degree-i polynomial model on this fold's analysis portion
  poly_fit <- glm(mpg ~ poly(horsepower, i), data = analysis(splits))

  # Observations held out from training in this fold
  fold_holdout <- assessment(splits)

  # Attach .fitted values and residuals, returning the result directly
  augment(poly_fit, newdata = fold_holdout) %>%
    mutate(.resid = mpg - .fitted)
}

# Compute the k-fold cross-validated MSE for a polynomial of degree `i`.
#
# i: highest-order polynomial term to include in the model
# vfold_data: an rsample cross-validation object with a `splits` column
poly_mse <- function(i, vfold_data){
  # Fit and assess the degree-i model on every fold, storing each
  # fold's mean squared residual
  fold_stats <- vfold_data %>%
    mutate(results = map(splits, holdout_results, i),
           mse = map_dbl(results, ~ mean(.$.resid ^ 2)))

  # Average the fold-level MSEs into the CV estimate
  mean(fold_stats$mse)
}

# split Auto into 10 folds
auto_cv10 <- vfold_cv(data = Auto, v = 10)

# estimate the CV MSE for polynomial degrees 1 through 5;
# tibble() replaces the deprecated data_frame() (identical behavior,
# including sequential evaluation so `terms` is visible to `mse_vfold`)
cv_mse <- tibble(terms = seq(from = 1, to = 5),
                 mse_vfold = map_dbl(terms, poly_mse, auto_cv10))
cv_mse
## # A tibble: 5 x 2
##   terms mse_vfold
##   <int>     <dbl>
## 1     1      24.2
## 2     2      19.2
## 3     3      19.3
## 4     4      19.3
## 5     5      18.9

Computational speed of LOOCV

Computational speed of 10-fold CV

k-fold CV in logistic regression

# Generate assessment-set predictions for the titanic interaction model
# on one cross-validation fold.
#
# splits: an rsample `rsplit` object.
# Returns the assessment data with `.fitted` (log-odds) and `pred`
# (0/1 predicted class at a 0.5 probability cutoff) columns added.
holdout_results <- function(splits) {
  # Fit the model to the training (analysis) set
  mod <- glm(Survived ~ Age * Sex, data = analysis(splits),
             family = binomial)
  
  # Save the heldout observations
  holdout <- assessment(splits)
  
  # `augment` will save the predictions with the holdout data set;
  # reuse `holdout` rather than recomputing assessment(splits)
  res <- augment(mod, newdata = holdout) %>%
    as_tibble() %>%
    mutate(pred = logit2prob(.fitted),
           pred = as.numeric(pred > .5))

  # Return the assessment data set with the additional columns
  res
}

# 10-fold CV error rate: predict each fold's held-out rows, then
# average the per-fold misclassification rates
titanic_cv10 <- vfold_cv(data = titanic, v = 10) %>%
  mutate(results = map(splits, holdout_results),
         error_rate = map_dbl(results, ~ mean(.$Survived != .$pred,
                                              na.rm = TRUE)))
mean(titanic_cv10$error_rate, na.rm = TRUE)
## [1] 0.2209604

Exercise: k-fold CV